{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To run this notebook, load it in a local Jupyter instance (`pip install jupyter`). You'll also need these dependencies:\n",
    "\n",
    "```\n",
    "pip install tf-nightly\n",
    "pip install google-cloud-storage\n",
    "pip install requests\n",
    "pip install google-api-python-client\n",
    "```\n",
    "\n",
    "You may also need to run this if you're not inside a Google Cloud VM:\n",
    "\n",
    "```\n",
    "gcloud auth application-default login\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You need to configure [OAuth](https://support.google.com/cloud/answer/6158849?hl=en). It's a complicated process, best described [here](https://github.com/googleapis/google-api-python-client/blob/master/docs/client-secrets.md). At the end, you download the `client_secrets.json` file and pass its path via `--client_secrets` below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "from astronet import tune\n",
    "from astronet import train\n",
    "\n",
    "# Name of the model config to inspect; reused by the cells below.\n",
    "config_name = 'final_alpha_1'\n",
    "\n",
    "# Populate the module-level FLAGS of `tune` and `train` by parsing an explicit\n",
    "# argument list instead of the command line.\n",
    "# The model name is spelled 'AstroCNNModel' so that it agrees with the\n",
    "# get_model_config() calls and the example tune.py command in this notebook.\n",
    "# NOTE(review): '--train_files' is passed as an empty string; presumably the\n",
    "# flag is required by the parsers but unused here -- confirm.\n",
    "tune.FLAGS = tune.parser.parse_args([\n",
    "  '--client_secrets', '../client_secrets.json',\n",
    "  '--model', 'AstroCNNModel',\n",
    "  '--config_name', config_name,\n",
    "  '--train_files', '',\n",
    "])\n",
    "train.FLAGS = train.parser.parse_args([\n",
    "  '--model', 'AstroCNNModel',\n",
    "  '--config_name', config_name,\n",
    "  '--train_files', '',\n",
    "])\n",
    "\n",
    "# API client used by the following cells to list studies and trials.\n",
    "client = tune.initialize_client()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Allow up to 100 characters per column when frames are displayed.\n",
    "pd.set_option('max_colwidth', 100)\n",
    "\n",
    "# Request the studies that live under the parent given by tune.study_parent().\n",
    "list_request = client.projects().locations().studies().list(parent=tune.study_parent())\n",
    "resp = list_request.execute()\n",
    "\n",
    "# Tabulate the response and order it newest-first by creation time.\n",
    "studies = pd.DataFrame(resp['studies']).sort_values('createTime', ascending=False)\n",
    "studies.head(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `studies` is sorted newest-first but keeps its original index labels, so\n",
    "# take the first row by position (.iloc) rather than by label 0.\n",
    "study = studies['name'].iloc[0]\n",
    "study_id = '{}/studies/{}'.format(tune.study_parent(), study.split('/')[-1])\n",
    "print(study_id)\n",
    "resp = client.projects().locations().studies().trials().list(parent=study_id).execute()\n",
    "# A response without a 'trials' key is treated as an empty study.\n",
    "trials = resp.get('trials', [])\n",
    "\n",
    "# Parallel lists: entry i of each one describes the same trial.\n",
    "metrics_loss = []\n",
    "params = []\n",
    "trial_ids = []\n",
    "for trial in trials:\n",
    "  # Skip trials that carry no final measurement.\n",
    "  if 'finalMeasurement' not in trial:\n",
    "    continue\n",
    "\n",
    "  metrics = trial['finalMeasurement']['metrics']\n",
    "  # Skip trials whose first metric has no value.\n",
    "  if 'value' not in metrics[0]:\n",
    "    continue\n",
    "\n",
    "  # Exactly one 'loss' metric is expected; the unpacking raises otherwise.\n",
    "  loss, = (m['value'] for m in metrics if m['metric'] == 'loss')\n",
    "\n",
    "  params.append(trial['parameters'])\n",
    "  metrics_loss.append(loss)\n",
    "  # The trial id is the last path component of the trial's resource name.\n",
    "  trial_ids.append(int(trial['name'].split('/')[-1]))\n",
    "\n",
    "print(len(trials), 'total trials')\n",
    "print(len(trial_ids), 'valid trials')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "matplotlib.rcParams.update({'font.size': 16})\n",
    "\n",
    "# Everything below indexes into metrics_loss, so fail early with a clear message.\n",
    "assert metrics_loss, 'no valid trials to plot'\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(trial_ids, metrics_loss)\n",
    "# The upper limit is hard-coded at 0.07; larger losses fall outside the view.\n",
    "ax.set_ylim(min(metrics_loss), 0.07)\n",
    "# Trial ids run along x; the loss is plotted on y.\n",
    "ax.set_xlabel(\"trial id\")\n",
    "ax.set_ylabel(\"validation loss\")\n",
    "\n",
    "# Losses in descending order, so the best ones sit at the end of the list.\n",
    "sorted_metrics = sorted(metrics_loss, reverse=True)\n",
    "# Loss of the 5th-best trial, or of the worst trial when there are fewer than 5.\n",
    "top_k = min(5, len(sorted_metrics))\n",
    "threshold = sorted_metrics[-top_k]\n",
    "\n",
    "# `best` is an index into trial_ids / metrics_loss / params, not a trial id.\n",
    "best = 0\n",
    "for i, (trial_id, loss) in enumerate(zip(trial_ids, metrics_loss)):\n",
    "  if loss <= threshold:\n",
    "    print(trial_id, loss)\n",
    "  # '<=' keeps the latest trial when several share the lowest loss.\n",
    "  if loss <= metrics_loss[best]:\n",
    "    best = i\n",
    "\n",
    "print('Best trial:', trial_ids[best])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import pprint\n",
    "from astronet import models\n",
    "\n",
    "# Load the named config, then hand each parameter of the best trial to\n",
    "# tune.map_param (presumably it updates the hparams in place -- confirm).\n",
    "config = models.get_model_config('AstroCNNModel', config_name)\n",
    "best_params = params[best]\n",
    "\n",
    "for param in best_params:\n",
    "  tune.map_param(config['hparams'], param, config['inputs'])\n",
    "\n",
    "# Show the configured number of training steps and the resulting hparams.\n",
    "print(tune.FLAGS.train_steps)\n",
    "pprint.pprint(config['hparams'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "import difflib\n",
    "import pprint\n",
    "from astronet import models\n",
    "\n",
    "# config1 stays untouched as the baseline; config2 receives the best trial's\n",
    "# parameters. NOTE(review): assumes each get_model_config() call returns an\n",
    "# independent object, otherwise the diff below would be empty -- confirm.\n",
    "config1 = models.get_model_config('AstroCNNModel', config_name)\n",
    "\n",
    "config2 = models.get_model_config('AstroCNNModel', config_name)\n",
    "for param in params[best]:\n",
    "  # Item access, matching how config['inputs'] is read in the previous cell.\n",
    "  tune.map_param(config2['hparams'], param, config2['inputs'])\n",
    "\n",
    "pp = pprint.PrettyPrinter()\n",
    "# The pformat() lines carry no trailing newline, so lineterm='' keeps the diff\n",
    "# header lines from adding blank lines once everything is joined with '\\n'.\n",
    "print('\\n'.join(difflib.unified_diff(\n",
    "  pp.pformat(config1).split('\\n'), pp.pformat(config2).split('\\n'),\n",
    "  n=0, lineterm=''\n",
    ")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```\n",
    "python astronet/tune.py --model=AstroCNNModel --config_name=final_alpha_1 --train_files=/mnt/tess/astronet/tfrecords-38-train/* --eval_files=/mnt/tess/astronet/tfrecords-38-val/* --train_steps=0 --tune_trials=1000 --client_secrets=../client_secrets.json --study_id=38_final_alpha_1_1\n",
    "```"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
